# validate_adjoint_suite.py
import sys, yaml, pandas as pd, numpy as np
from collections import defaultdict

# Run configuration, loaded once at import time. The context manager closes the
# file handle deterministically instead of leaking it until garbage collection.
with open("configs/default.yaml", "r", encoding="utf-8") as _cfg_file:
    CFG = yaml.safe_load(_cfg_file)
# Summary table of the adjoint volume sweep that main() validates.
CSV = "data/results/vol4_wilson_loop_adjoint_volume_sweep/adjoint_volume_summary.csv"

def get(cfg, *cands, required=True, default=None):
    """Return the value at the first dotted key path in *cands* present in *cfg*.

    Each candidate (e.g. ``"adjoint_volume.volumes"``) is split on ``.`` and
    walked through nested dicts; candidates are tried in the order given.

    Returns:
        The value at the first fully-resolved path (which may itself be None),
        or ``default`` when no candidate resolves and ``required`` is false.

    Raises:
        KeyError: when no candidate resolves and ``required`` is true.
    """
    unresolved = object()  # sentinel distinguishing "missing" from a stored None

    def walk(dotted):
        cursor = cfg
        for segment in dotted.split("."):
            if not isinstance(cursor, dict) or segment not in cursor:
                return unresolved
            cursor = cursor[segment]
        return cursor

    for dotted in cands:
        hit = walk(dotted)
        if hit is not unresolved:
            return hit

    if not required:
        return default
    raise KeyError(f"None of {cands} found in config")

# Columns that together identify one point of the parameter grid.
_GRID_KEYS = ["L", "b", "gauge", "k", "n0"]


def _expected_grid_size(adj):
    """Return ``(expected_point_count, from_config)`` for the sweep grid.

    The expected count is the product of the lengths of the L, gauge, b, k and
    n0 lists in CFG, each looked up under several candidate keys via ``get``.
    If a list is missing (``get`` raises KeyError) or is not sized (``len``
    raises TypeError), fall back to the number of unique grid points in *adj*
    and report ``from_config=False``: in that case the count check compares
    the CSV against itself and cannot fail.
    """
    try:
        Ls     = get(CFG, "adjoint_volume.volumes", "L_values", "L_list")
        gauges = get(CFG, "adjoint_volume.gauge_groups", "adjoint_volume.gauges")
        bs     = get(CFG, "b_values", "adjoint_volume.b_values")
        ks     = get(CFG, "k_values", "adjoint_volume.k_values")
        n0s    = get(CFG, "n0_values", "adjoint_volume.n0_values")
        return len(Ls) * len(gauges) * len(bs) * len(ks) * len(n0s), True
    except (KeyError, TypeError):
        return adj.drop_duplicates(subset=_GRID_KEYS).shape[0], False


def _check_kernel_hashes(adj):
    """Report and return whether each (L, gauge) group has at most one kernel_sha256."""
    bad = []
    for (L, g), grp in adj.groupby(["L", "gauge"]):
        # Sorted (by str, so mixed dtypes cannot raise) for reproducible output.
        hashes = sorted(grp["kernel_sha256"].dropna().unique(), key=str)
        if len(hashes) > 1:
            bad.append(((L, g), hashes))
    if bad:
        print("\n[FAIL] Multiple kernel hashes for the same (L,gauge):")
        for (L, g), hs in bad:
            print(f"  (L={L}, gauge={g}) → {hs}")
        return False
    print("[OK] Kernel hash consistent within each (L,gauge).")
    return True


def _check_flipcount(adj):
    """Report and return whether every positive flipcount_n equals 2*L^2."""
    expect_n = 2 * adj["L"] ** 2
    # Rows with flipcount_n <= 0 (or NaN) are treated as "not set" and skipped.
    mism = adj.loc[(adj["flipcount_n"] > 0) & (adj["flipcount_n"] != expect_n)]
    if not mism.empty:
        print("\n[FAIL] flipcount_n != 2*L^2 for some rows:")
        shown = mism.assign(_expect_n=expect_n)  # aligns on the row index
        print(shown[["L", "gauge", "b", "k", "n0", "flipcount_n", "_expect_n"]].head(10))
        return False
    print("[OK] flipcount_n matches 2*L^2 (or not set).")
    return True


def _report_errors(adj):
    """Print string_tension_err diagnostics; these never fail the run."""
    share_tiny = float((adj["string_tension_err"] < 1e-12).mean())
    print(f"error < 1e-12 share: {share_tiny:.3f}")
    if share_tiny > 0.25:
        print("[WARN] Many extremely small errors; consider larger bootstrap_block or reps.")

    print("\nPer-(L,gauge) mean±std err:")
    print(adj.groupby(["L", "gauge"])["string_tension_err"].agg(["mean", "std", "count"]).head(12))


def main():
    """Validate the adjoint volume-sweep summary CSV and exit with a status code.

    Checks, in order: the number of unique grid points against the size
    implied by CFG; a single kernel hash per (L, gauge); flipcount_n == 2*L^2
    where set. Then prints error diagnostics. Exits via ``sys.exit`` with 0
    when every check passes and 2 otherwise.
    """
    adj = pd.read_csv(CSV)

    expected, from_config = _expected_grid_size(adj)
    uniq = adj.drop_duplicates(subset=_GRID_KEYS).shape[0]
    print(f"grid rows: {len(adj)}  unique points: {uniq}  expected: {expected}")
    ok = (uniq == expected)
    if not from_config:
        print("[WARN] Grid axes not found in config; expected count taken from the CSV itself, "
              "so the grid-size check is not meaningful.")
    elif ok:
        print("[OK] Unique grid points match the configured grid size.")
    else:
        print(f"[FAIL] Unique grid points ({uniq}) != configured grid size ({expected}).")

    # Run every check (no short-circuit) so all failures are reported.
    ok = _check_kernel_hashes(adj) and ok
    ok = _check_flipcount(adj) and ok

    _report_errors(adj)

    sys.exit(0 if ok else 2)

if __name__ == "__main__":
    # Run the validation only when executed as a script, not on import.
    main()
